/**

Copyright (C) SYSTAP, LLC DBA Blazegraph 2006-2016.  All rights reserved.

Contact:
     SYSTAP, LLC DBA Blazegraph
     2501 Calvert ST NW #106
     Washington, DC 20008
     licenses@blazegraph.com

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; version 2 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/
/*
 * Created on Aug 30, 2011
 */

package com.bigdata.rdf.sparql.ast.optimizers;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.bigdata.bop.BOp;
import com.bigdata.bop.BOpUtility;
import com.bigdata.bop.IBindingSet;
import com.bigdata.bop.IVariable;
import com.bigdata.rdf.sparql.ast.ASTUtil;
import com.bigdata.rdf.sparql.ast.IQueryNode;
import com.bigdata.rdf.sparql.ast.NamedSubqueriesNode;
import com.bigdata.rdf.sparql.ast.NamedSubqueryInclude;
import com.bigdata.rdf.sparql.ast.NamedSubqueryRoot;
import com.bigdata.rdf.sparql.ast.QueryNodeWithBindingSet;
import com.bigdata.rdf.sparql.ast.QueryRoot;
import com.bigdata.rdf.sparql.ast.StaticAnalysis;
import com.bigdata.rdf.sparql.ast.SubqueryBase;
import com.bigdata.rdf.sparql.ast.SubqueryRoot;
import com.bigdata.rdf.sparql.ast.VarNode;
import com.bigdata.rdf.sparql.ast.eval.AST2BOpContext;

import cutthecrap.utils.striterators.Striterator;

/**
 * Class identifies the join variables for each instance in which a named
 * subquery solution set is incorporated into the query plan.
 * <p>
 * As part of that work this optimizer also:
 * <ul>
 * <li>reorders the named subqueries so that each one is evaluated after every
 * named subquery whose solution set it INCLUDEs, and records the INCLUDEd
 * names in the DEPENDS_ON annotation of each named subquery;</li>
 * <li>verifies that each INCLUDE names either a named subquery or a named
 * solution set for which the {@link AST2BOpContext} can report statistics;</li>
 * <li>verifies that each named subquery has a non-blank, unique name and is
 * consumed by at least one INCLUDE.</li>
 * </ul>
 *
 * @see NamedSubqueryRoot
 * @see NamedSubqueryInclude
 *
 * @author <a href="mailto:thompsonbry@users.sourceforge.net">Bryan Thompson</a>
 * @version $Id$
 */
public class ASTNamedSubqueryOptimizer implements IASTOptimizer {

    /**
     * Orders the named subqueries, validates the INCLUDEs against them, and
     * assigns join variables to each INCLUDE and to each named subquery. The
     * {@link QueryRoot} is modified in place and returned together with the
     * input binding sets.
     *
     * @throws RuntimeException
     *             if the named subqueries INCLUDE one another in a cycle.
     * @throws RuntimeException
     *             if there is an {@link NamedSubqueryInclude} for a named
     *             solution set which is not generated by the query.
     * @throws RuntimeException
     *             if there is an {@link NamedSubqueryRoot} for a named solution
     *             set which is not consumed by the query.
     * @throws RuntimeException
     *             if there is more than one {@link NamedSubqueryRoot} for a
     *             given named solution set.
     */
    @Override
    public QueryNodeWithBindingSet optimize(
        final AST2BOpContext context, final QueryNodeWithBindingSet input) {

        final IQueryNode queryNode = input.getQueryNode();
        final IBindingSet[] bindingSet = input.getBindingSets();

        final QueryRoot queryRoot = (QueryRoot) queryNode;

        final NamedSubqueriesNode namedSubqueries = queryRoot
                .getNamedSubqueries();

        if (namedSubqueries == null || namedSubqueries.isEmpty()) {

            // NOP.
            return new QueryNodeWithBindingSet(queryRoot, bindingSet);

        }

        /*
         * Order the named subqueries in order to support nested includes.
         *
         * Note: The named subqueries must form an acyclic graph. They can
         * INCLUDE one another, but not in patterns which form cycles. This puts
         * them into an evaluation order.
         */
        orderNamedSubqueries(queryRoot, namedSubqueries);

        /*
         * orderNamedSubqueries() installs a new NamedSubqueriesNode on the
         * QueryRoot. Validate and annotate against the node which is actually
         * installed rather than the reference obtained above.
         */
        final NamedSubqueriesNode orderedSubqueries = queryRoot
                .getNamedSubqueries();

        // The set of all INCLUDEs in the query.
        final NamedSubqueryInclude[] allIncludes = findAllIncludes(queryRoot);

        // Verify that a named subquery or solution set exists for each INCLUDE.
        assertNamedSubqueryForEachInclude(context, orderedSubqueries,
                allIncludes);

        /*
         * Verify that each named subquery is consumed by at least one include
         * somewhere in the query.
         */
        assertEachNamedSubqueryIsUsed(orderedSubqueries, allIncludes);

        // Figure out the join variables for each INCLUDE.
        assignJoinVars(queryRoot, context, orderedSubqueries, allIncludes);

        return new QueryNodeWithBindingSet(queryRoot, bindingSet);

    }

    /**
     * Return all {@link NamedSubqueryInclude}s which appear in the WHERE clause
     * of the main query (including those nested inside sub-selects) and in the
     * WHERE clauses of the named subqueries.
     */
    static private NamedSubqueryInclude[] findAllIncludes(
            final QueryRoot queryRoot) {

        final List<NamedSubqueryInclude> list = new LinkedList<NamedSubqueryInclude>();

        // INCLUDEs in the main WHERE clause and its sub-selects.
        collectIncludes((BOp) queryRoot.getWhereClause(), list);

        // INCLUDEs in the named subqueries.
        final NamedSubqueriesNode namedSubqueries = queryRoot
                .getNamedSubqueries();

        if (namedSubqueries != null) {

            for (NamedSubqueryRoot root : namedSubqueries) {

                collectIncludes((BOp) root.getWhereClause(), list);

            }

        }

        return list.toArray(new NamedSubqueryInclude[list.size()]);

    }

    /**
     * Return all {@link NamedSubqueryInclude}s which appear in the WHERE clause
     * of the given subquery, including those nested inside its sub-selects.
     *
     * TODO This seems to be inefficient. We do not need to proceed
     * {@link SubqueryBase} by {@link SubqueryBase}.
     * {@link BOpUtility#visitAll(BOp, Class)} can be used to locate all
     * INCLUDEs in the entire query and then we can build up whatever indices we
     * need in optimize() and use them elsewhere as required.
     */
    static private List<NamedSubqueryInclude> findSubqueryIncludes(
            final SubqueryBase queryRoot) {

        final List<NamedSubqueryInclude> list = new LinkedList<NamedSubqueryInclude>();

        collectIncludes((BOp) queryRoot.getWhereClause(), list);

        return list;

    }

    /**
     * Append to <i>list</i> the {@link NamedSubqueryInclude}s visited by a
     * post-order traversal of <i>whereClause</i>, then recursively those found
     * in the WHERE clause of each {@link SubqueryRoot} visited by the same
     * traversal.
     *
     * @param whereClause
     *            The WHERE clause to scan (may be <code>null</code>, in which
     *            case nothing is added).
     * @param list
     *            The list to which the INCLUDEs are appended.
     */
    static private void collectIncludes(final BOp whereClause,
            final List<NamedSubqueryInclude> list) {

        if (whereClause == null)
            return;

        final Striterator includeItr = new Striterator(
                BOpUtility.postOrderIterator(whereClause));

        includeItr.addTypeFilter(NamedSubqueryInclude.class);

        while (includeItr.hasNext()) {

            list.add((NamedSubqueryInclude) includeItr.next());

        }

        final Striterator subqueryItr = new Striterator(
                BOpUtility.postOrderIterator(whereClause));

        subqueryItr.addTypeFilter(SubqueryRoot.class);

        while (subqueryItr.hasNext()) {

            final SubqueryRoot subquery = (SubqueryRoot) subqueryItr.next();

            collectIncludes((BOp) subquery.getWhereClause(), list);

        }

    }

    /**
     * Verify that a named subquery or named solution set exists for each
     * INCLUDE.
     *
     * @param context
     *            Used to look up named solution sets which are not produced by
     *            a named subquery in this query.
     * @param namedSubqueries
     *            The named subqueries of the query.
     * @param allIncludes
     *            Every INCLUDE in the query.
     *
     * @throws RuntimeException
     *             if an INCLUDE has a missing or blank name, or if neither a
     *             named subquery nor a named solution set exists for it.
     */
    static private void assertNamedSubqueryForEachInclude(
            final AST2BOpContext context,
            final NamedSubqueriesNode namedSubqueries,
            final NamedSubqueryInclude[] allIncludes) {

        // The names of the solution sets produced by the named subqueries.
        final Set<String> producedSets = new HashSet<String>();

        for (NamedSubqueryRoot aNamedSubquery : namedSubqueries) {

            producedSets.add(aNamedSubquery.getName());

        }

        for (NamedSubqueryInclude anInclude : allIncludes) {

            final String namedSet = anInclude.getName();

            if (namedSet == null || namedSet.trim().length() == 0)
                throw new RuntimeException(
                        "Missing or illegal name for include.");

            if (producedSets.contains(namedSet)) {
                // Produced by a named subquery in this query.
                continue;
            }

            try {
                /*
                 * Not produced by this query, so it must name a pre-existing
                 * solution set. The lookup is expected to throw if there is no
                 * such solution set.
                 */
                context.getSolutionSetStats(namedSet);
            } catch (RuntimeException e) {
                throw new RuntimeException(
                        "No subquery produces the solution set: " + namedSet, e);
            }

        }

    }

    /**
     * Verify that each named subquery has a legal, unique name and is consumed
     * by at least one INCLUDE somewhere in the query.
     *
     * @param namedSubqueries
     *            The named subqueries of the query.
     * @param allIncludes
     *            Every INCLUDE in the query.
     *
     * @throws RuntimeException
     *             if a named subquery has a missing or blank name, if a name is
     *             declared more than once, or if a named subquery is not
     *             consumed by any INCLUDE.
     */
    static private void assertEachNamedSubqueryIsUsed(
            final NamedSubqueriesNode namedSubqueries,
            final NamedSubqueryInclude[] allIncludes) {

        // The names of the solution sets consumed by some INCLUDE.
        final Set<String> consumedSets = new HashSet<String>();

        for (NamedSubqueryInclude anInclude : allIncludes) {

            consumedSets.add(anInclude.getName());

        }

        // The set of all named solution sets produced by this query.
        final Set<String> namedSets = new HashSet<String>();

        for (NamedSubqueryRoot aNamedSubquery : namedSubqueries) {

            final String namedSet = aNamedSubquery.getName();

            if (namedSet == null || namedSet.trim().length() == 0)
                throw new RuntimeException(
                        "Missing or illegal name for named subquery.");

            if (!namedSets.add(namedSet)) {

                throw new RuntimeException("NamedSet declared more than once: "
                        + namedSet);

            }

            if (!consumedSets.contains(namedSet)) {

                throw new RuntimeException(
                        "Named subquery results are not used by this query: "
                                + namedSet);

            }

        }

    }

    /**
     * Figure out the join variables for each INCLUDE. If the join variables
     * were already assigned to a {@link NamedSubqueryInclude}, then we just
     * make sure that the {@link NamedSubqueryRoot} will produce a suitable hash
     * index. If an INCLUDE does not have its join variables pre-assigned, then
     * we do a static analysis of the query and figure out which shared
     * variables MUST be bound. The set of shared variables is assigned as the
     * join variables. Again, we verify that a suitable hash index will be
     * produced for that INCLUDE.
     * <p>
     * When the INCLUDEs of one named solution set require different join
     * variables, the intersection of those join variable sets is assigned to
     * the named subquery and to each of its INCLUDEs so that a single hash
     * index serves all of them.
     * <p>
     * Note: If the join variables were not pre-assigned (by a query hint) and
     * no join variables are identified by a static analysis then a full N x M
     * cross product of the solutions must be tested and filtered for those
     * solutions which join. This is a lot of effort when compared with a hash
     * join. Having the right join variables is very important for performance.
     *
     * @param queryRoot
     *            The query (used for the static analysis).
     * @param context
     *            The evaluation context (used for the static analysis).
     * @param namedSubqueries
     *            The named subqueries of the query.
     * @param allIncludes
     *            Every INCLUDE in the query.
     *
     * @see https://sourceforge.net/apps/trac/bigdata/ticket/405
     */
    static private void assignJoinVars(//
            final QueryRoot queryRoot,//
            final AST2BOpContext context,//
            final NamedSubqueriesNode namedSubqueries,//
            final NamedSubqueryInclude[] allIncludes) {

        final StaticAnalysis sa = new StaticAnalysis(queryRoot, context);

        // Group the INCLUDEs by the name of the solution set they consume.
        final Map<String, List<NamedSubqueryInclude>> includesByName = new LinkedHashMap<String, List<NamedSubqueryInclude>>();

        for (NamedSubqueryInclude anInclude : allIncludes) {

            List<NamedSubqueryInclude> list = includesByName.get(anInclude
                    .getName());

            if (list == null) {
                list = new LinkedList<NamedSubqueryInclude>();
                includesByName.put(anInclude.getName(), list);
            }

            list.add(anInclude);

        }

        for (NamedSubqueryRoot aNamedSubquery : namedSubqueries) {

            final List<NamedSubqueryInclude> includes = includesByName
                    .get(aNamedSubquery.getName());

            if (includes == null || includes.isEmpty()) {
                /*
                 * Not consumed by any INCLUDE. This is normally rejected by
                 * assertEachNamedSubqueryIsUsed(); there is nothing to assign.
                 */
                continue;
            }

            /*
             * Collect each distinct joinvar[] combination for those includes.
             * Each joinvar[] is sorted so that equivalent combinations share a
             * common order.
             */
            final Set<JoinVars> distinctJoinVarsSet = new LinkedHashSet<JoinVars>();

            for (NamedSubqueryInclude anInclude : includes) {

                distinctJoinVarsSet.add(new JoinVars(resolveJoinVars(sa,
                        aNamedSubquery, anInclude)));

            }

            if (distinctJoinVarsSet.size() == 1) {

                // Just one set of join variables, so we will use that.
                final JoinVars joinVars = distinctJoinVarsSet.iterator().next();

                aNamedSubquery.setJoinVars(ASTUtil.convert(joinVars.toArray()));

                continue;

            }

            /*
             * More than one set of join variables is required by the INCLUDEs,
             * so use the intersection of the join variables across all of
             * them. Starting from the first set preserves its (sorted) order.
             */
            final Iterator<JoinVars> itr = distinctJoinVarsSet.iterator();

            final Set<IVariable<?>> sharedVariables = new LinkedHashSet<IVariable<?>>(
                    itr.next().vars());

            while (itr.hasNext()) {

                sharedVariables.retainAll(itr.next().vars());

            }

            /*
             * The join variables which are shared across all contexts in which
             * this named solution set is joined back into the query.
             */
            final VarNode[] sharedJoinVars = ASTUtil.convert(sharedVariables
                    .toArray(new IVariable[] {}));

            // Set the shared join variables on the named subquery.
            aNamedSubquery.setJoinVars(sharedJoinVars);

            for (NamedSubqueryInclude anInclude : includes) {

                // Set the shared join variables on each subquery include.
                anInclude.setJoinVars(sharedJoinVars);

            }

        }

    }

    /**
     * Determine the join variables for a single INCLUDE, sort them, and set
     * them back on the INCLUDE in sorted order.
     *
     * @param sa
     *            The static analysis of the query.
     * @param aNamedSubquery
     *            The named subquery producing the INCLUDEd solution set.
     * @param anInclude
     *            The INCLUDE.
     *
     * @return The sorted join variables for the INCLUDE.
     */
    @SuppressWarnings("rawtypes")
    static private IVariable[] resolveJoinVars(final StaticAnalysis sa,
            final NamedSubqueryRoot aNamedSubquery,
            final NamedSubqueryInclude anInclude) {

        final IVariable[] joinvars;

        if (anInclude.getJoinVars() == null) {

            /**
             * Since no query hint was used, then figure out the join variables
             * using a static analysis of the query.
             *
             * Note: Since the named subqueries run with only the exogenous
             * bindings as input, anything which is exogenously bound plus
             * anything which is known bound can serve as a join variable.
             *
             * TODO There is a StaticAnalysis bug - it fails to consider the
             * exogenous bindings when computing the definitely bound
             * variables.
             *
             * @see <a
             *      href="https://sourceforge.net/apps/trac/bigdata/ticket/412">
             *      getDefinatelyBound() ignores exogenous variables </a>
             *
             *      TODO Optimize case where there are no exogenous bindings
             *      such that the sole source solution for the named subquery
             *      is an empty solution set.
             *
             * @see <a
             *      href="http://sourceforge.net/apps/trac/bigdata/ticket/535">
             *      Optimize JOIN VARS for Sub-Selects </a>
             */
            final Set<IVariable<?>> set = new LinkedHashSet<IVariable<?>>();

            sa.getJoinVars(aNamedSubquery, anInclude, set);

            joinvars = set.toArray(new IVariable[set.size()]);

        } else {

            // Get the user specified join variables.
            joinvars = ASTUtil.convert(anInclude.getJoinVars());

        }

        // Sort so that equivalent join variable sets share a common order.
        Arrays.sort(joinvars);

        // Set them back on the include in sorted order.
        anInclude.setJoinVars(ASTUtil.convert(joinvars));

        return joinvars;

    }

    /**
     * Order the named subqueries based on nested includes and install the
     * ordered {@link NamedSubqueriesNode} on the {@link QueryRoot}. Also sets
     * the DEPENDS_ON annotation of each named subquery to the names of every
     * solution set (named subquery or pre-existing named solution set) which
     * it INCLUDEs.
     * <p>
     * Named subqueries which do not INCLUDE another named subquery come first,
     * in their original order. The remaining ones are then added by repeated
     * sweeps, each adding every named subquery whose INCLUDEd named subqueries
     * have all been placed already.
     *
     * TODO This should reuse the same arrays/collections that are generated for
     * the other logic in this class. No need to repeatedly traverse the query
     * looking for INCLUDEs.
     *
     * @throws RuntimeException
     *             if the named subqueries INCLUDE one another in a cycle
     *             (including a named subquery which INCLUDEs itself).
     */
    static private void orderNamedSubqueries(final QueryRoot queryRoot,
            final NamedSubqueriesNode namedSubqueries) {

        // The names of the solution sets produced by the named subqueries.
        final Set<String> subqueryNames = new HashSet<String>();

        for (NamedSubqueryRoot aNamedSubquery : namedSubqueries) {

            subqueryNames.add(aNamedSubquery.getName());

        }

        /*
         * Map from named subquery root to the named subqueries on which it
         * depends. Those named solutions must be computed before any named
         * subquery root which will consume them.
         */
        final Map<NamedSubqueryRoot, List<String>> pending = new LinkedHashMap<NamedSubqueryRoot, List<String>>();

        for (NamedSubqueryRoot aNamedSubquery : namedSubqueries) {

            // All INCLUDEd names: named subqueries and named solution sets.
            final List<String> includes = new LinkedList<String>();

            // Only the INCLUDEd names which are produced by named subqueries.
            final List<String> subqueryDeps = new LinkedList<String>();

            for (NamedSubqueryInclude include : findSubqueryIncludes(aNamedSubquery)) {

                final String name = include.getName();

                includes.add(name);

                if (subqueryNames.contains(name)) {

                    subqueryDeps.add(name);

                } // else name gives a named solution set.

            }

            /*
             * Set the DEPENDS_ON annotation: named subqueries and solution
             * sets. This is done BEFORE the node is used as a map key so the
             * key is not modified after insertion (NOTE(review): its hashCode
             * may depend on its annotations - confirm).
             */
            aNamedSubquery.setDependsOn(includes.toArray(new String[includes
                    .size()]));

            pending.put(aNamedSubquery, subqueryDeps);

        }

        /*
         * Create a new NamedSubqueriesNode which corresponds to a valid
         * evaluation order for the named subqueries.
         */
        final Set<String> processed = new HashSet<String>();

        final NamedSubqueriesNode newNode = new NamedSubqueriesNode();

        // First, those which do not depend on any other named subquery.
        Iterator<Map.Entry<NamedSubqueryRoot, List<String>>> iter = pending
                .entrySet().iterator();

        while (iter.hasNext()) {

            final Map.Entry<NamedSubqueryRoot, List<String>> entry = iter
                    .next();

            if (entry.getValue().isEmpty()) {

                newNode.add(entry.getKey());
                processed.add(entry.getKey().getName());
                iter.remove();

            }

        }

        // Then sweep until every named subquery has been placed.
        while (!pending.isEmpty()) {

            boolean progress = false;

            iter = pending.entrySet().iterator();

            while (iter.hasNext()) {

                final Map.Entry<NamedSubqueryRoot, List<String>> entry = iter
                        .next();

                if (processed.containsAll(entry.getValue())) {

                    newNode.add(entry.getKey());
                    processed.add(entry.getKey().getName());
                    iter.remove();
                    progress = true;

                }

            }

            if (!progress) {

                /*
                 * A full sweep placed nothing, so every remaining named
                 * subquery (transitively) depends on itself. Without this
                 * check the loop would never terminate.
                 */
                final List<String> unresolved = new LinkedList<String>();

                for (NamedSubqueryRoot aNamedSubquery : pending.keySet()) {

                    unresolved.add(aNamedSubquery.getName());

                }

                throw new RuntimeException(
                        "Cyclic INCLUDE dependencies among named subqueries: "
                                + unresolved);

            }

        }

        // Update the QueryRoot with the named subquery evaluation order.
        queryRoot.setNamedSubqueries(newNode);

    }

    /**
     * Wrapper giving a combination of join variables value semantics (set
     * equality) so that distinct combinations can be collected in a
     * {@link Set}. A bare array cannot be used for this because arrays use
     * identity-based equals() and hashCode().
     */
    private static class JoinVars {

        /** The join variables, de-duplicated, in first-seen order. */
        private final Set<IVariable<?>> vars;

        /** Cached hash code; consistent with {@link #equals(Object)}. */
        private final int hashCode;

        public JoinVars(final IVariable<?>[] vars) {

            this.vars = new LinkedHashSet<IVariable<?>>();

            Collections.addAll(this.vars, vars);

            /*
             * Hash the set rather than the array: equals() compares the sets,
             * and the array may contain duplicates, so an array-based hash
             * could differ for two equal instances.
             */
            this.hashCode = this.vars.hashCode();

        }

        /** Return an unmodifiable view of the join variables. */
        public Set<IVariable<?>> vars() {

            return Collections.unmodifiableSet(vars);

        }

        /** Return the join variables as a new array, in set order. */
        public IVariable<?>[] toArray() {

            return vars.toArray(new IVariable[vars.size()]);

        }

        @Override
        public int hashCode() {
            return hashCode;
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o)
                return true;
            if (!(o instanceof JoinVars))
                return false;
            final JoinVars t = (JoinVars) o;
            return vars.equals(t.vars);
        }

    }

}
